from sentence_transformers import SentenceTransformer
import torch
import json
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import argparse
from heads import get_matching_head
import re

class MatchingInference:
    """Binary matcher for (answer, reason) text pairs.

    Wraps a SentenceTransformer encoder and a matching head, both loaded from
    a model directory, and exposes batched probability prediction plus a
    file-based evaluation routine that reports accuracy, a classification
    report and a confusion matrix.
    """

    # Known head names, in detection priority order.
    _CANDIDATE_HEAD_TYPES = (
        "base",
        "deep_mlp",
        "cos_sim",
        "residual",
        "cross_attn",
        "feature",
        "cos_sim_deeper",
    )

    def __init__(self, model_dir, device=None):
        """Load the encoder and the matching head stored under ``model_dir``.

        Args:
            model_dir: Directory containing ``embedding_model/`` and
                ``matching_head.pt``. Its path must contain one of the known
                head type names; that name selects the head architecture.
            device: Device for both models (string or ``torch.device``).
                Defaults to ``"cuda"`` when available, otherwise ``"cpu"``.

        Raises:
            ValueError: If no known head type name occurs in ``model_dir``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = str(device)

        # Resolve the head type first so a bad path fails before any weights load.
        head_type = self._detect_head_type(model_dir)
        print(f"Detected head_type: {head_type}")

        self.embedding_model = SentenceTransformer(
            os.path.join(model_dir, "embedding_model"),
            trust_remote_code=True,
            device=self.device,
        )
        embedding_dim = self.embedding_model.get_sentence_embedding_dimension()

        self.matching_head = get_matching_head(head_type, embedding_dim)
        # map_location lets a checkpoint saved on GPU be loaded on any device.
        state_dict = torch.load(
            os.path.join(model_dir, "matching_head.pt"),
            map_location=self.device,
        )
        self.matching_head.load_state_dict(state_dict)
        self.matching_head = self.matching_head.to(self.device)

        # Inference only: switch both modules to eval mode.
        self.matching_head.eval()
        self.embedding_model.eval()

    @staticmethod
    def _detect_head_type(model_dir):
        """Return the head type name embedded in the ``model_dir`` path.

        The first candidate (in priority order) found in the path wins, except
        that a longer matching candidate containing it is preferred, so that
        ``cos_sim_deeper`` is not mistaken for ``cos_sim``.

        Raises:
            ValueError: If no candidate name occurs in ``model_dir``.
        """
        matches = [
            name
            for name in MatchingInference._CANDIDATE_HEAD_TYPES
            if name in model_dir
        ]
        if not matches:
            raise ValueError(f"Could not detect valid head_type from model_dir name: {model_dir}")

        first = matches[0]
        # Among matches that extend the first hit, take the most specific one.
        return max((name for name in matches if first in name), key=len)

    @torch.no_grad()
    def predict_batch(self, answers, reasons, batch_size=64):
        """Return the match probability for every (answer, reason) pair.

        Args:
            answers: Sequence of answer strings.
            reasons: Sequence of reason strings, same length as ``answers``.
            batch_size: Number of pairs encoded and scored per step.

        Returns:
            A list of floats (sigmoid of the head's logits), one per pair, in
            input order.

        Raises:
            ValueError: If ``answers`` and ``reasons`` differ in length.
        """
        if len(answers) != len(reasons):
            raise ValueError(
                f"answers and reasons must have the same length "
                f"(got {len(answers)} and {len(reasons)})"
            )

        all_probs = []
        for start in range(0, len(answers), batch_size):
            stop = start + batch_size
            emb_a = self.embedding_model.encode(
                answers[start:stop], convert_to_tensor=True, normalize_embeddings=True
            )
            emb_b = self.embedding_model.encode(
                reasons[start:stop], convert_to_tensor=True, normalize_embeddings=True
            )

            outputs = self.matching_head({"embedding_a": emb_a, "embedding_b": emb_b})
            probs = torch.sigmoid(outputs["logits"].squeeze(-1))
            if probs.dim() == 0:
                # A single-sample batch may collapse to a scalar; keep it a list.
                probs = probs.unsqueeze(0)

            all_probs.extend(probs.tolist())

        return all_probs

    @torch.no_grad()
    def evaluate(self, test_file, batch_size=64, threshold=0.5, save_dir=None):
        """Evaluate on a JSON test file and optionally save the artefacts.

        The file must hold a list of objects with ``answer``, ``reason`` and a
        binary ``label`` (0 or 1). A pair is predicted positive when its
        probability is ``>= threshold``.

        Args:
            test_file: Path to the JSON test file.
            batch_size: Batch size forwarded to ``predict_batch``.
            threshold: Decision threshold applied to the probabilities.
            save_dir: If given, the confusion-matrix plot, the misclassified
                samples and the text log are written into this directory.

        Returns:
            The accuracy as a float in ``[0, 1]``.

        Raises:
            ValueError: If the file has no samples or a label is not 0/1.
        """
        with open(test_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not data:
            raise ValueError(f"No samples found in test file: {test_file}")

        answers = [item["answer"] for item in data]
        reasons = [item["reason"] for item in data]
        labels = [item["label"] for item in data]

        unexpected = set(labels) - {0, 1}
        if unexpected:
            raise ValueError(f"Labels must be 0 or 1, found: {sorted(unexpected, key=repr)}")

        probs = self.predict_batch(answers, reasons, batch_size=batch_size)
        preds = [int(p >= threshold) for p in probs]

        correct = sum(int(p == l) for p, l in zip(preds, labels))
        total = len(labels)
        acc = correct / total

        # Fixing labels=[0, 1] guarantees a 2x2 matrix even when only one
        # class occurs, so the four counts are always well defined.
        cm = confusion_matrix(labels, preds, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()

        log_text = "".join([
            f"\nTest Accuracy: {acc*100:.2f}% ({correct}/{total})\n",
            "\nClassification Report:\n",
            classification_report(labels, preds, digits=4),
            "\n\nConfusion Matrix:\n",
            f"True Positive (TP): {tp}\n",
            f"False Positive (FP): {fp}\n",
            f"False Negative (FN): {fn}\n",
            f"True Negative (TN): {tn}\n",
        ])
        print(log_text)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

            self._plot_confusion_matrix(cm, save_path=os.path.join(save_dir, "confusion_matrix.png"))

            # Keep every misclassified pair for later inspection.
            error_samples = [
                {
                    "answer": ans,
                    "reason": reason,
                    "label": label,
                    "pred": pred,
                    "probability": prob,
                }
                for ans, reason, label, pred, prob in zip(answers, reasons, labels, preds, probs)
                if pred != label
            ]
            error_path = os.path.join(save_dir, "error_samples.json")
            with open(error_path, "w", encoding="utf-8") as f:
                json.dump(error_samples, f, indent=2, ensure_ascii=False)

            log_path = os.path.join(save_dir, "evaluation_log.txt")
            with open(log_path, "w", encoding="utf-8") as f:
                f.write(log_text)

            print(f"\nConfusion matrix, error samples, and evaluation log saved to: {save_dir}")

        return acc

    def _plot_confusion_matrix(self, cm, save_path=None):
        """Draw ``cm`` as an annotated heatmap.

        Args:
            cm: 2x2 integer confusion matrix (rows: ground truth, columns:
                prediction).
            save_path: If given, the figure is written there and closed;
                otherwise it is left open on the current pyplot state.
        """
        plt.figure(figsize=(6,5))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Pred 0", "Pred 1"], yticklabels=["True 0", "True 1"])
        plt.xlabel("Prediction")
        plt.ylabel("Ground Truth")
        plt.title("Confusion Matrix")
        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
            plt.close()

if __name__ == "__main__":
    # Command-line entry point: evaluate one model on one or more test files,
    # writing each dataset's artefacts into its own sub-directory of --save_dir.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_dir", type=str, required=True)
    cli.add_argument(
        "--test_files",
        type=str,
        nargs="+",
        required=True,
        help="List of test file paths, separated by space.",
    )
    cli.add_argument("--save_dir", type=str, required=True)
    cli.add_argument("--batch_size", type=int, default=64)
    cli.add_argument("--threshold", type=float, default=0.5)
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="If set, will overwrite existing results in save_dir.",
    )
    opts = cli.parse_args()

    matcher = MatchingInference(opts.model_dir)

    # Files that must all be present for a dataset to count as already evaluated.
    expected_outputs = ("confusion_matrix.png", "error_samples.json", "evaluation_log.txt")

    for test_path in opts.test_files:
        # Dataset name is "<parent folder>_<file name without extension>".
        folder, filename = os.path.split(test_path)
        dataset_name = f"{os.path.basename(folder)}_{os.path.splitext(filename)[0]}"
        dataset_save_dir = os.path.join(opts.save_dir, dataset_name)

        if not opts.overwrite:
            missing = [
                name
                for name in expected_outputs
                if not os.path.exists(os.path.join(dataset_save_dir, name))
            ]
            if not missing:
                print(f"Skipping {dataset_name}: results already exist in {dataset_save_dir}")
                continue

        print(f"\n=== Evaluating on {dataset_name} ===")
        matcher.evaluate(
            test_file=test_path,
            batch_size=opts.batch_size,
            threshold=opts.threshold,
            save_dir=dataset_save_dir,
        )
